import collections
import copy
import math
from abc import ABCMeta

from river import base, utils


class DenStream(base.Clusterer):
    r"""DenStream

    DenStream [^1] is a clustering algorithm for evolving data streams.
    DenStream can discover clusters with arbitrary shape and is robust against
    noise (outliers).

    "Dense" micro-clusters (named core-micro-clusters) summarise the clusters
    of arbitrary shape. A pruning strategy based on the concepts of potential
    and outlier micro-clusters guarantees the precision of the weights of the
    micro-clusters with limited memory.

    The algorithm is divided into two parts:

    **Online micro-cluster maintenance (learning)**

    For a new point `p`:

    * Try to merge `p` into either the nearest `p-micro-cluster` (potential),
    `o-micro-cluster` (outlier), or create a new `o-micro-cluster` and insert it
    into the outlier buffer.

    * For each `T_p` iterations, consider the weights of all potential and
    outlier micro-clusters. If their weights are smaller than a certain
    threshold (different for each type of micro-clusters), the micro-cluster is
    deleted.

    **Offline generation of clusters on-demand (clustering)**

    A variant of the DBSCAN algorithm [^2] is used, such that all
    density-connected p-micro-clusters determine the final clusters.

    Parameters
    ----------
    decaying_factor
        Parameter that controls the importance of historical data to current cluster.
        Note that `decaying_factor` has to be different from `0`.

    core_weight_threshold
        Parameter to determine the threshold of outlier relative to core micro-clusters.
        Note that `core_weight_threshold * tolerance_factor` has to be greater than `1` or
        less than `0`.

    tolerance_factor
        Parameter to determine the threshold of outliers relative to core micro-cluster.
        In a normal setting, this parameter is usually set within the range `[0,1]`.
        Once again, note that `core_weight_threshold * tolerance_factor` has to be greater than `1`
        or less than `0`.

    radius
        This parameter is passed onto the DBSCAN offline algorithm as the $\epsilon$ parameter
        when a clustering request arrives.

    Attributes
    ----------
    n_clusters
        Number of clusters generated by the algorithm.

    clusters
        A set of final clusters of type `MicroCluster`, which means that these cluster include all the required
        information, including number of points, creation time, weight, (weighted) linear sum, (weighted) square sum,
        center and radius.

    p_micro_clusters
        The p micro-clusters that are generated by the algorithm. When a generating cluster request arrives,
        these p-micro-clusters will go through a variant of DBSCAN algorithm to determine the final clusters.

    o_micro_clusters
        The outlier buffer, separating the processing of the potential core-micro-cluster and outlier-micro-clusters.

    References
    ----------
    [^1]: Feng et al (2006, pp 328-339). Density-Based Clustering over an Evolving Data Stream with Noise.
          In Proceedings of the Sixth SIAM International Conference on Data Mining,
          April 20–22, 2006, Bethesda, MD, USA.
    [^2]: Ester et al (1996). A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases
          with Noise. In KDD-96 Proceedings, AAAI.

    Examples
    ----------

    The following example uses the default parameters of the algorithm to test its functionality. It can easily be seen
    that the set of evolving points `X` are designed so that there can be a clear picture drawn on how the clusters can
    be generated.

    >>> from river import cluster
    >>> from river import stream

    >>> X = [
    ...     [-1, -0.5], [-1, -0.625], [-1, -0.75], [-1, -1], [-1, -1.125], [-1, -1.25],
    ...     [-1.5, -0.5], [-1.5, -0.625], [-1.5, -0.75], [-1.5, -1], [-1.5, -1.125], [-1.5, -1.25],
    ...     [1, 1.5], [1, 1.75], [1, 2], [4, 1.25], [4, 1.5], [4, 2.25],
    ...     [4, 2.5], [4, 3], [4, 3.25], [4, 3.5], [4, 3.75], [4, 4],
    ... ]

    >>> denstream = cluster.DenStream(decaying_factor = 0.01,
    ...                               core_weight_threshold = 1.01,
    ...                               tolerance_factor = 1.0005,
    ...                               radius = 0.5)

    >>> for x, _ in stream.iter_array(X):
    ...     denstream = denstream.learn_one(x)

    >>> denstream.predict_one({0: -1, 1: -2})
    0

    >>> denstream.predict_one({0:5, 1:4})
    1

    >>> denstream.n_clusters
    2

    """

    def __init__(
        self,
        decaying_factor: float = 0.25,
        core_weight_threshold: float = 5,
        tolerance_factor: float = 0.5,
        radius: float = 2,
    ):
        super().__init__()
        # logical clock; incremented once per learned point
        self.time_stamp = -1
        self.initialized = False
        self.decaying_factor = decaying_factor
        self.core_weight_threshold = core_weight_threshold
        self.tolerance_factor = tolerance_factor
        self.radius = radius

        # number of clusters generated by applying the variant of the DBSCAN
        # algorithm on the p-micro-cluster centers
        self.n_clusters = 0
        self.clusters = {}
        self.centers = {}
        self.p_micro_clusters = {}
        self.o_micro_clusters = {}

    @property
    def _time_point(self):
        """The pruning period T_p: the minimal time span a p-micro-cluster
        needs to fade below the potential threshold (eq. from the paper)."""
        return math.ceil(
            1
            / self.decaying_factor
            * math.log(
                self.tolerance_factor
                * self.core_weight_threshold
                / (self.tolerance_factor * self.core_weight_threshold - 1)
            )
        )

    @staticmethod
    def _distance(point_a, point_b):
        """Euclidean distance between two points (dicts keyed by feature index)."""
        return math.sqrt(utils.math.minkowski_distance(point_a, point_b, 2))

    def _find_closest_cluster_index(self, point, micro_clusters):
        """Return the key of the micro-cluster whose center is closest to
        `point`, or `-1` if `micro_clusters` is empty."""
        min_distance = math.inf
        closest_cluster_index = -1
        for i, micro_cluster_i in micro_clusters.items():
            distance = self._distance(micro_cluster_i.center, point)
            if distance < min_distance:
                min_distance = distance
                closest_cluster_index = i
        return closest_cluster_index

    @staticmethod
    def _next_index(micro_clusters):
        """Return an unused dict key for a new micro-cluster.

        `len(micro_clusters)` is NOT safe here: after pruning/promotion pops,
        keys are no longer contiguous and `len` may collide with an existing
        key, silently overwriting a live micro-cluster.
        """
        return max(micro_clusters, default=-1) + 1

    def _merge(self, point):
        """Try to absorb `point` into the nearest p-micro-cluster, otherwise
        the nearest o-micro-cluster, otherwise start a new o-micro-cluster."""
        # wrap point p in a micro-cluster so its statistics can be merged
        mc_from_p = DenStreamMicroCluster(
            x=point,
            timestamp=self.time_stamp,
            decaying_factor=self.decaying_factor,
            current_time=self.time_stamp,
        )

        merged_status = False

        if self.p_micro_clusters:
            # try to merge p into its nearest p-micro-cluster c_p
            closest_pmc_index = self._find_closest_cluster_index(
                point, self.p_micro_clusters
            )
            # Perform the trial merge on a deep copy: merging into the live
            # cluster would (a) mutate it even when the merge is rejected and
            # (b) require a second `add` on acceptance, double-counting p.
            updated_pmc = copy.deepcopy(self.p_micro_clusters[closest_pmc_index])
            updated_pmc.add(mc_from_p)
            if updated_pmc.radius <= self.radius:
                # commit: p is merged into its nearest c_p
                self.p_micro_clusters[closest_pmc_index] = updated_pmc
                merged_status = True

        if not merged_status and self.o_micro_clusters:
            closest_omc_index = self._find_closest_cluster_index(
                point, self.o_micro_clusters
            )
            # same copy-then-commit discipline as for the p-micro-clusters
            updated_omc = copy.deepcopy(self.o_micro_clusters[closest_omc_index])
            updated_omc.add(mc_from_p)
            if updated_omc.radius <= self.radius:
                if (
                    updated_omc.weight
                    > self.tolerance_factor * self.core_weight_threshold
                ):
                    # c_o has grown dense enough: remove it from the
                    # outlier buffer and promote it to a p-micro-cluster
                    self.o_micro_clusters.pop(closest_omc_index)
                    self.p_micro_clusters[
                        self._next_index(self.p_micro_clusters)
                    ] = updated_omc
                else:
                    # commit: p is merged into its nearest c_o
                    self.o_micro_clusters[closest_omc_index] = updated_omc
                merged_status = True

        if not merged_status:
            # p fits no existing micro-cluster (or none exists yet): create a
            # new o-micro-cluster from p and insert it into the outlier buffer
            self.o_micro_clusters[self._next_index(self.o_micro_clusters)] = mc_from_p

    def _is_directly_density_reachable(self, c_p, c_q):
        """True when both clusters are dense enough (weight > mu) and their
        centers are close enough (within 2 * eps AND within the sum of radii)."""
        if (
            c_p.weight > self.core_weight_threshold
            and c_q.weight > self.core_weight_threshold
        ):
            # check distance of the two centers against 2 * eps
            distance = self._distance(c_p.center, c_q.center)
            if distance < 2 * self.radius and distance < c_p.radius + c_q.radius:
                return True
        return False

    def _query_neighbor(self, cluster):
        """Return the p-micro-clusters directly density-reachable from
        `cluster`, excluding `cluster` itself."""
        return [
            pmc
            for pmc in self.p_micro_clusters.values()
            if pmc is not cluster and self._is_directly_density_reachable(cluster, pmc)
        ]

    @staticmethod
    def _generate_clusters_from_labels(cluster_labels):
        """Fuse all p-micro-clusters sharing a label into one final cluster.

        Parameters
        ----------
        cluster_labels
            Mapping of p-micro-cluster -> integer label (all labels assigned).

        Returns
        -------
        `(n_clusters, clusters)` where `clusters` maps label -> fused cluster.
        """
        clusters = {}

        for i in range(max(cluster_labels.values()) + 1):
            pmcs_with_label_i = [
                pmc for pmc, label_pmc in cluster_labels.items() if label_pmc == i
            ]
            # Fuse into a deep copy: `add` mutates its receiver, and mutating a
            # live p-micro-cluster here would corrupt the model on every
            # clustering request.
            cluster = copy.deepcopy(pmcs_with_label_i[0])
            for pmc in pmcs_with_label_i[1:]:
                cluster.add(pmc)
            clusters[i] = cluster

        return len(clusters), clusters

    def learn_one(self, x, sample_weight=None):
        """Update the micro-cluster structure with one point.

        Parameters
        ----------
        x
            A dict of feature values keyed by feature index.
        sample_weight
            Present for API consistency; unused by this implementation.

        Returns
        -------
        self
        """
        self.time_stamp += 1

        # the very first point initializes the structure
        if not self.initialized:
            mc_from_x = DenStreamMicroCluster(
                x=x,
                timestamp=self.time_stamp,
                decaying_factor=self.decaying_factor,
                current_time=self.time_stamp,
            )
            if mc_from_x.weight >= self.tolerance_factor * self.core_weight_threshold:
                self.p_micro_clusters[0] = mc_from_x
            else:
                self.o_micro_clusters[0] = mc_from_x
            self.initialized = True
            return self

        # refresh current_time so faded statistics (weight, center, radius)
        # of every micro-cluster are evaluated at the present time
        for p_micro_cluster in self.p_micro_clusters.values():
            p_micro_cluster.current_time = self.time_stamp
        for o_micro_cluster in self.o_micro_clusters.values():
            o_micro_cluster.current_time = self.time_stamp

        self._merge(x)

        # every T_p steps, prune micro-clusters whose weight has decayed
        # below their respective thresholds
        if self.time_stamp % self._time_point == 0:
            for i, p_micro_cluster_i in list(self.p_micro_clusters.items()):
                if (
                    p_micro_cluster_i.weight
                    < self.tolerance_factor * self.core_weight_threshold
                ):
                    # c_p is no longer a potential core-micro-cluster
                    self.p_micro_clusters.pop(i)
            for j, o_micro_cluster_j in list(self.o_micro_clusters.items()):
                # xi: the minimal weight an o-micro-cluster of this age should
                # have to still be able to grow into a p-micro-cluster
                xi = (
                    2
                    ** (
                        -self.decaying_factor
                        * (
                            self.time_stamp
                            - o_micro_cluster_j.creation_time
                            + self._time_point
                        )
                    )
                    - 1
                ) / (2 ** (-self.decaying_factor * self._time_point) - 1)
                if o_micro_cluster_j.weight < xi:
                    # c_o is real noise: delete it
                    self.o_micro_clusters.pop(j)
        return self

    def predict_one(self, x, sample_weight=None):
        """Handle a clustering request.

        Runs a variant of DBSCAN (Ester et al.) over the current
        p-micro-clusters, caches the resulting final clusters, and returns the
        index of the final cluster whose center is closest to `x`.

        Parameters
        ----------
        x
            A dict of feature values keyed by feature index.
        sample_weight
            Present for API consistency; unused by this implementation.
        """
        if not self.p_micro_clusters:
            # nothing to cluster yet: reset cached results and fall back to
            # the default label instead of crashing on an empty label set
            self.n_clusters, self.clusters, self.centers = 0, {}, {}
            return 0

        # initiate labels of p-micro-clusters to None (= unprocessed)
        labels = {pmc: None for pmc in self.p_micro_clusters.values()}

        # cluster counter; in this algorithm cluster labels start with 0
        c = -1

        for pmc in self.p_micro_clusters.values():
            # previously processed while expanding an earlier cluster
            if labels[pmc] is not None:
                continue
            # no noise label is needed, as there is no min-points requirement;
            # open the next cluster
            c += 1
            labels[pmc] = c
            # _query_neighbor already yields the density-reachable seed set
            seed_set = collections.deque(self._query_neighbor(pmc))
            while seed_set:
                candidate = seed_set.popleft()
                # skip points already claimed by this or an earlier cluster
                if labels[candidate] is not None:
                    continue
                labels[candidate] = c
                # expand the cluster through the candidate's own neighbors;
                # only still-unlabeled neighbors need processing
                for neighbor in self._query_neighbor(candidate):
                    if labels[neighbor] is None:
                        seed_set.append(neighbor)

        self.n_clusters, self.clusters = self._generate_clusters_from_labels(labels)
        self.centers = {i: cluster.center for i, cluster in self.clusters.items()}

        return self._find_closest_cluster_index(x, self.clusters)


class DenStreamMicroCluster(metaclass=ABCMeta):
    """DenStream micro-cluster.

    Maintains the time-faded statistics of a group of points: a weight, a
    weighted linear sum and a weighted squared sum, all of which decay as
    `2 ** (-decaying_factor * t)`. The center and radius used by DenStream
    are derived from these statistics on access.

    Parameters
    ----------
    x
        The initial point, as a dict keyed by feature index `0 .. dim-1`.
    timestamp
        Arrival (logical) time of the initial point.
    decaying_factor
        The fading factor lambda used in `2 ** (-lambda * t)`.
    current_time
        The current logical time at which faded statistics are evaluated.
        Updated externally as the stream progresses.
    """

    def __init__(self, x=None, timestamp=None, decaying_factor=None, current_time=None):

        self.x = x
        self.timestamp = timestamp
        self.creation_time = timestamp
        self.decaying_factor = decaying_factor
        self.current_time = current_time

        if x is not None and timestamp is not None:

            # Initial weight, initial weighted linear sum (IWLS) and initial
            # weighted squared sum (IWSS) are stored un-faded; the fading
            # factor is applied on access by the properties below.

            self.N = 1
            self.dim = len(x)
            # weight contribution of a point that arrived at `timestamp`
            scale = 2 ** (self.decaying_factor * self.timestamp)
            self.initial_weight = scale
            # Copy the point: `add` updates LS in place, so aliasing the
            # caller's dict would mutate the user's input (and couple any
            # micro-clusters built from the same dict).
            self.LS = dict(x)
            self.SS = {i: x[i] * x[i] for i in range(self.dim)}
            self.IWLS = {i: scale * x[i] for i in range(self.dim)}
            self.IWSS = {i: scale * (x[i] * x[i]) for i in range(self.dim)}

    @property
    def weighted_linear_sum(self):
        """The weighted linear sum, faded to `current_time`."""
        fade = 2 ** (-self.current_time * self.decaying_factor)
        return {i: fade * self.IWLS[i] for i in range(self.dim)}

    @property
    def weighted_squared_sum(self):
        """The weighted squared sum, faded to `current_time`."""
        fade = 2 ** (-self.current_time * self.decaying_factor)
        return {i: fade * self.IWSS[i] for i in range(self.dim)}

    @property
    def weight(self):
        """The total weight of the micro-cluster, faded to `current_time`."""
        return (2 ** (-self.current_time * self.decaying_factor)) * self.initial_weight

    @property
    def center(self):
        """The weighted center of the micro-cluster."""
        return {i: self.weighted_linear_sum[i] / self.weight for i in range(self.dim)}

    @property
    def radius(self):
        """The radius of the micro-cluster.

        `abs` guards against tiny negative values caused by floating-point
        round-off in the variance-style expression.
        """
        return math.sqrt(
            abs(
                self.norm(self.weighted_squared_sum) / self.weight
                - (self.norm(self.weighted_linear_sum) / self.weight) ** 2
            )
        )

    def add(self, cluster):
        """Absorb another micro-cluster's statistics into this one (in place).

        Both clusters must have the same dimensionality. The creation time is
        kept as the earlier of the two.
        """
        assert self.dim == cluster.dim
        self.N += cluster.N
        self.initial_weight += cluster.initial_weight
        for i in range(self.dim):
            self.LS[i] += cluster.LS[i]
            self.SS[i] += cluster.SS[i]
            self.IWLS[i] += cluster.IWLS[i]
            self.IWSS[i] += cluster.IWSS[i]
        if self.creation_time > cluster.creation_time:
            self.creation_time = cluster.creation_time

    @staticmethod
    def norm(x):
        """Euclidean norm of a vector represented as a dict of values."""
        return math.sqrt(sum(val * val for val in x.values()))